{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DLinear model validation using offline batch inference\n",
    "\n",
    "<div align=\"left\">\n",
    "<a target=\"_blank\" href=\"https://console.anyscale.com/\"><img src=\"https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf\"></a>&nbsp;\n",
    "<a href=\"https://github.com/anyscale/e2e-timeseries\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d\"></a>&nbsp;\n",
    "</div>\n",
    "\n",
    "This tutorial demonstrates how to perform batch inference using the DLinear model and Ray Data.\n",
    "The process involves loading the model checkpoint, preparing the test data, running inference in batches, and evaluating the performance.\n",
    "\n",
    "Note that this notebook requires the pre-trained model artifacts that the previous \"Distributed training of a DLinear time-series model\" notebook generates.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/refs/heads/main/images/batch_inference.png\" width=800>\n",
    "\n",
    "\n",
    "The preceding figure illustrates how different blocks of data process concurrently at various stages of the pipeline. This parallel execution maximizes resource utilization and throughput.\n",
    "\n",
    "Note that this diagram is a simplification for various reasons:\n",
    "\n",
    "* Only one worker processes each data pipeline stage\n",
    "* Backpressure mechanisms may throttle upstream operators to prevent overwhelming downstream stages\n",
    "* Dynamic repartitioning often occurs as data moves through the pipeline, changing block counts and sizes\n",
    "* Available resources change as the cluster autoscales\n",
    "* System failures may disrupt the clean sequential flow shown in the diagram\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert\"> <b> Ray Data streaming execution</b> \n",
    "\n",
    "❌ **Traditional batch execution**, non-streaming like Spark without pipelining or SageMaker Batch Transform:\n",
    "- Reads the entire dataset into memory or a persistent intermediate format\n",
    "- Only then starts applying transformations, such as `.map`, `.filter`, etc.\n",
    "- Higher memory pressure and startup latency\n",
    "\n",
    "✅ **Streaming execution** with Ray Data:\n",
    "- Starts processing blocks as they load, without waiting for the entire dataset\n",
    "- Reduces memory footprint, preventing out-of-memory errors, and speeds up time to first output\n",
    "- Increases resource utilization by reducing idle time\n",
    "- Enables online-style inference pipelines with minimal latency\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/refs/heads/main/images/streaming.gif\" width=700>\n",
    "\n",
    "**Note**: Ray Data operates as batch processing with streaming execution rather than a real-time stream processing engine like Flink or Kafka Streams. This approach proves especially useful for iterative ML workloads, ETL pipelines, and preprocessing before training or inference. Ray typically delivers a [**2-17x throughput improvement**](https://www.anyscale.com/blog/offline-batch-inference-comparing-ray-apache-spark-and-sagemaker#-results-of-throughput-from-experiments) over solutions like Spark and SageMaker Batch Transform.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable importing from e2e_timeseries module.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start by setting up the environment and imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import ray\n",
    "import torch\n",
    "\n",
    "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n",
    "\n",
    "import e2e_timeseries\n",
    "from e2e_timeseries.data_factory import data_provider\n",
    "from e2e_timeseries.metrics import metric\n",
    "from e2e_timeseries.model import DLinear"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the Ray cluster with the `e2e_timeseries` module, so that newly spawned workers can import it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(runtime_env={\"py_modules\": [e2e_timeseries]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, set up the DLinear model configuration as well as job configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the best checkpoint path from the metadata file created in the training notebook.\n",
    "best_checkpoint_metadata_fpath = os.path.join(\n",
    "    \"/mnt/cluster_storage/checkpoints\", \"best_checkpoint_path.txt\"\n",
    ")\n",
    "with open(best_checkpoint_metadata_fpath, \"r\") as f:\n",
    "    best_checkpoint_path = f.read().strip()\n",
    "\n",
    "config = {\n",
    "    \"checkpoint_path\": best_checkpoint_path,\n",
    "    \"num_data_workers\": 1,\n",
    "    \"features\": \"S\",\n",
    "    \"target\": \"OT\",\n",
    "    \"smoke_test\": False,\n",
    "    \"seq_len\": 96,\n",
    "    \"label_len\": 48,\n",
    "    \"pred_len\": 96,\n",
    "    \"individual\": False,\n",
    "    \"batch_size\": 64,\n",
    "    \"num_predictor_replicas\": 4,\n",
    "}\n",
    "\n",
    "\n",
    "def _process_config(config: dict) -> dict:\n",
    "    \"\"\"Helper function to process and update configuration.\"\"\"\n",
    "    # Configure encoder input size based on task type.\n",
    "    if config[\"features\"] == \"M\" or config[\"features\"] == \"MS\":\n",
    "        config[\"enc_in\"] = 7  # ETTh1 has 7 features when multi-dimensional prediction is enabled\n",
    "    else:\n",
    "        config[\"enc_in\"] = 1\n",
    "\n",
    "    # Ensure paths are absolute.\n",
    "    config[\"checkpoint_path\"] = os.path.abspath(config[\"checkpoint_path\"])\n",
    "\n",
    "    config[\"num_gpus_per_worker\"] = 1.0\n",
    "\n",
    "    config[\"train_only\"] = False  # Load test subset\n",
    "    return config\n",
    "\n",
    "\n",
    "# Set derived values.\n",
    "config = _process_config(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data ingest\n",
    "\n",
    "First, load the test dataset as a Ray Data Dataset. Use `.show(1)` to trigger the execution for a single row,\n",
    "because Ray Data is lazily evaluates datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(ignore_reinit_error=True)\n",
    "\n",
    "print(\"Loading test data...\")\n",
    "ds = data_provider(config, flag=\"test\")\n",
    "ds.show(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This cell defines the Predictor class. It loads the trained DLinear model from a checkpoint and\n",
    "processes input batches to produce predictions. The __call__ method performs inference\n",
    "on a given batch of NumPy arrays.\n",
    "\n",
    "Ray Data's actor-based processing enables loading the model weights and transferring them to GPU only once and reusing them across batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Predictor:\n",
    "    \"\"\"Actor class for performing inference with the DLinear model.\"\"\"\n",
    "\n",
    "    def __init__(self, checkpoint_path: str, config: dict):\n",
    "        self.config = config\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "        # Load model from checkpoint.\n",
    "        self.model = DLinear(config).float()\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
    "        self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "        self.model.to(self.device)\n",
    "        self.model.eval()\n",
    "\n",
    "    def __call__(self, batch: dict[str, np.ndarray]) -> dict:\n",
    "        \"\"\"Process a batch of data for inference (numpy batch format).\"\"\"\n",
    "        # Convert input batch to tensor.\n",
    "        batch_x = torch.from_numpy(batch[\"x\"]).float().to(self.device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = self.model(batch_x)  # Shape (N, pred_len, features_out)\n",
    "\n",
    "        # Determine feature dimension based on config.\n",
    "        f_dim = -1 if self.config[\"features\"] == \"MS\" else 0\n",
    "        outputs = outputs[:, -self.config[\"pred_len\"] :, f_dim:]\n",
    "        outputs_np = outputs.cpu().numpy()\n",
    "\n",
    "        # Extract the target part from the batch.\n",
    "        batch_y = batch[\"y\"]\n",
    "        batch_y_target = batch_y[:, -self.config[\"pred_len\"] :]\n",
    "\n",
    "        return {\"predictions\": outputs_np, \"targets\": batch_y_target}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds.map_batches(\n",
    "    Predictor,\n",
    "    fn_constructor_kwargs={\"checkpoint_path\": config[\"checkpoint_path\"], \"config\": config},\n",
    "    batch_size=config[\"batch_size\"],\n",
    "    concurrency=config[\"num_predictor_replicas\"],\n",
    "    num_gpus=config[\"num_gpus_per_worker\"],\n",
    "    batch_format=\"numpy\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, perform minor post-processing to get the results in the desired dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def postprocess_items(item: dict) -> dict:\n",
    "    # Squeeze singleton dimensions for predictions and targets if necessary.\n",
    "    if item[\"predictions\"].shape[-1] == 1:\n",
    "        item[\"predictions\"] = item[\"predictions\"].squeeze(-1)\n",
    "    if item[\"targets\"].shape[-1] == 1:\n",
    "        item[\"targets\"] = item[\"targets\"].squeeze(-1)\n",
    "    return item\n",
    "\n",
    "\n",
    "ds = ds.map(postprocess_items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, execute all of these lazy steps and materialize them into memory using `take_all()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trigger the lazy execution of the entire Ray pipeline.\n",
    "all_results = ds.take_all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the results are in memory, calculate some validation metrics for the trained DLinear model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate predictions and targets from all batches.\n",
    "all_predictions = np.concatenate([item[\"predictions\"] for item in all_results], axis=0)\n",
    "all_targets = np.concatenate([item[\"targets\"] for item in all_results], axis=0)\n",
    "\n",
    "# Compute evaluation metrics.\n",
    "mae, mse, rmse, mape, mspe, rse = metric(all_predictions, all_targets)\n",
    "\n",
    "print(\"\\n--- Test Results ---\")\n",
    "print(f\"MSE: {mse:.3f}\")\n",
    "print(f\"MAE: {mae:.3f}\")\n",
    "print(f\"RMSE: {rmse:.3f}\")\n",
    "print(f\"MAPE: {mape:.3f}\")\n",
    "print(f\"MSPE: {mspe:.3f}\")\n",
    "print(f\"RSE: {rse:.3f}\")\n",
    "\n",
    "print(\"\\nOffline inference finished!\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
